{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Otter Image Demo (In-context Learning)\n",
    "\n",
    "Here is an example of multi-modal ICL (in-context learning) with 🦦 Otter. We provide two demo images with corresponding instructions and answers, then ask the model to generate an answer for a third image given our instruction. You may change the instruction and see how the model responds.\n",
    "\n",
    "You can also try our [online demo](https://otter.cliangyu.com/) to see more in-context learning demonstrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "import transformers\n",
    "from PIL import Image\n",
    "\n",
    "# Directory appended to sys.path so that `otter_ai` can be imported below.\n",
    "# The path is relative, so it is resolved against the kernel's working directory.\n",
    "OTTER_SRC_DIR = \"../../src\"\n",
    "sys.path.append(OTTER_SRC_DIR)\n",
    "from otter_ai import OtterForConditionalGeneration\n",
    "\n",
    "# Checkpoint identifier handed to from_pretrained, together with device_map=\"auto\".\n",
    "OTTER_CHECKPOINT = \"luodian/OTTER-9B-LA-InContext\"\n",
    "model = OtterForConditionalGeneration.from_pretrained(OTTER_CHECKPOINT, device_map=\"auto\")\n",
    "\n",
    "# Convenience alias for the tokenizer object exposed by the model.\n",
    "tokenizer = model.text_tokenizer\n",
    "\n",
    "# CLIP image processor constructed with its default arguments.\n",
    "image_processor = transformers.CLIPImageProcessor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two in-context demonstration images followed by the query image to be captioned.\n",
    "DEMO_IMAGE_ONE_URL = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "DEMO_IMAGE_TWO_URL = \"http://images.cocodataset.org/test-stuff2017/000000028137.jpg\"\n",
    "QUERY_IMAGE_URL = \"http://images.cocodataset.org/test-stuff2017/000000028352.jpg\"\n",
    "REQUEST_TIMEOUT_S = 30  # seconds; without a timeout a stalled download would hang the cell\n",
    "\n",
    "\n",
    "def load_image(url, timeout=REQUEST_TIMEOUT_S):\n",
    "    \"\"\"Download `url` and open it as a PIL image.\n",
    "\n",
    "    Args:\n",
    "        url: HTTP(S) address of the image file.\n",
    "        timeout: Seconds to wait on the server before giving up.\n",
    "\n",
    "    Returns:\n",
    "        The PIL image opened from the streamed response body.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server answers with a 4xx/5xx status.\n",
    "        requests.Timeout: If connecting or waiting for the response exceeds `timeout`.\n",
    "    \"\"\"\n",
    "    response = requests.get(url, stream=True, timeout=timeout)\n",
    "    # Fail here with the HTTP status instead of letting PIL choke on an error page.\n",
    "    response.raise_for_status()\n",
    "    return Image.open(response.raw)\n",
    "\n",
    "\n",
    "demo_image_one = load_image(DEMO_IMAGE_ONE_URL)\n",
    "demo_image_two = load_image(DEMO_IMAGE_TWO_URL)\n",
    "query_image = load_image(QUERY_IMAGE_URL)\n",
    "\n",
    "# Preprocess the three images together, then insert a singleton axis after the image\n",
    "# axis and a leading singleton axis.\n",
    "# NOTE(review): presumably (batch, num_images, frames, C, H, W) -- confirm against the model.\n",
    "pixel_values = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors=\"pt\")[\"pixel_values\"]\n",
    "vision_x = pixel_values.unsqueeze(1).unsqueeze(0)\n",
    "\n",
    "# One <image> placeholder per image above; the trailing <answer> is left open for the model.\n",
    "prompt = \"<image>User: a photo of GPT:<answer> two cats sleeping.<|endofchunk|><image>User: a photo of GPT:<answer> a bathroom sink.<|endofchunk|><image>User: a photo of GPT:<answer>\"\n",
    "model.text_tokenizer.padding_side = \"left\"\n",
    "lang_x = model.text_tokenizer([prompt], return_tensors=\"pt\")\n",
    "\n",
    "# Cast the image tensor to the dtype of the model's parameters.\n",
    "model_dtype = next(model.parameters()).dtype\n",
    "vision_x = vision_x.to(dtype=model_dtype)\n",
    "lang_x_input_ids = lang_x[\"input_ids\"]\n",
    "lang_x_attention_mask = lang_x[\"attention_mask\"]\n",
    "\n",
    "# Token-id sequences passed as bad_words_ids so generation does not emit these role markers.\n",
    "bad_words_id = model.text_tokenizer([\"User:\", \"GPT1:\", \"GFT:\", \"GPT:\"], add_special_tokens=False).input_ids\n",
    "generated_text = model.generate(\n",
    "    vision_x=vision_x.to(model.device),\n",
    "    lang_x=lang_x_input_ids.to(model.device),\n",
    "    attention_mask=lang_x_attention_mask.to(model.device),\n",
    "    max_new_tokens=512,\n",
    "    num_beams=3,\n",
    "    no_repeat_ngram_size=3,\n",
    "    bad_words_ids=bad_words_id,\n",
    ")\n",
    "\n",
    "# Keep only the text after the last <answer> tag and before the next end-of-chunk marker,\n",
    "# then drop surrounding whitespace and any enclosing double quotes.\n",
    "decoded_output = model.text_tokenizer.decode(generated_text[0])\n",
    "last_answer = decoded_output.split(\"<answer>\")[-1].strip()\n",
    "parsed_output = last_answer.split(\"<|endofchunk|>\")[0].strip().strip('\"')\n",
    "\n",
    "print(\"Generated text: \", parsed_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "otter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
